Dependencies

import numpy as np
import pandas as pd
import seaborn as sns
import albumentations as A
import matplotlib.pyplot as plt
import os, gc, cv2, random, re
import warnings, math, sys, json, pprint, pdb

import tensorflow as tf
from tensorflow.keras import backend as K
import tensorflow_hub as hub

from sklearn.model_selection import train_test_split

Setup

# Device setup: connect to a TPU when requested, otherwise fall back to the
# default (CPU / single-GPU) distribution strategy. Defines DEVICE, strategy,
# AUTOTUNE and REPLICAS for the rest of the notebook.
DEVICE = 'TPU' #@param ["None", "'GPU'", "'TPU'"] {type:"raw", allow-input: true}

if DEVICE == "TPU":
    print("connecting to TPU...")
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Running on TPU ', tpu.master())
    except ValueError:
        print("Could not connect to TPU")
        tpu = None

    if tpu:
        try:
            print("initializing  TPU ...")
            tf.config.experimental_connect_to_cluster(tpu)
            tf.tpu.experimental.initialize_tpu_system(tpu)
            strategy = tf.distribute.experimental.TPUStrategy(tpu)
            print("TPU initialized")
        # was `except _:` — `_` is an undefined name, so any failure here
        # raised NameError instead of being handled
        except Exception:
            print("failed to initialize TPU")
            # fall back so `strategy` is still defined below; previously a
            # failed init left DEVICE == "TPU" and `strategy` unbound,
            # crashing at the REPLICAS line
            DEVICE = "GPU"
    else:
        DEVICE = "GPU"

if DEVICE != "TPU":
    print("Using default strategy for CPU and single GPU")
    strategy = tf.distribute.get_strategy()

if DEVICE == "GPU":
    print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))


AUTOTUNE = tf.data.experimental.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync  # 8 on a v2/v3 TPU, 1 on CPU/GPU
print(f'REPLICAS: {REPLICAS}')
connecting to TPU...
Running on TPU  grpc://10.118.221.90:8470
initializing  TPU ...
INFO:tensorflow:Initializing the TPU system: grpc://10.118.221.90:8470
INFO:tensorflow:Initializing the TPU system: grpc://10.118.221.90:8470
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Finished initializing TPU system.
WARNING:absl:`tf.distribute.experimental.TPUStrategy` is deprecated, please use  the non experimental symbol `tf.distribute.TPUStrategy` instead.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
TPU initialized
REPLICAS: 8
def seed_everything(seed=0):
    """Seed every RNG in play (hash, python, numpy, tensorflow) for reproducibility."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'  # ask TF for deterministic kernels
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

def is_colab():
    """Return True when running inside Google Colab.

    Detection is via the IPython shell repr. Previously a bare call to
    ``get_ipython()`` raised NameError under a plain Python interpreter;
    that case now returns False.
    """
    try:
        return 'google.colab' in str(get_ipython())
    except NameError:  # not running under IPython/Jupyter at all
        return False

Tip: Setting a seed helps reproduce results. Setting the debug parameter will run the model for a smaller number of epochs to validate the architecture.
#@title Debugger { run: "auto" }
# Global run-mode switches (Colab form fields).
SEED = 16  # passed to seed_everything() below
DEBUG = True #@param {type:"boolean"}
TRAIN = True #@param {type:"boolean"}
INFERENCE = True #@param {type:"boolean"}

IS_COLAB = is_colab()

warnings.simplefilter('ignore')  # silence library deprecation chatter
seed_everything(SEED)

print(f"Using TensorFlow v{tf.__version__}")
Using TensorFlow v2.4.0
# Mount Google Drive when on Colab — the dataset/working dirs live there.
if IS_COLAB:
    from google.colab import drive
    drive.mount('/content/gdrive', force_remount=True)
Mounted at /content/gdrive
# Resolve input/working directories for both the Colab and Kaggle layouts,
# then chdir into the working dir so relative outputs land there.
project_name = 'tpu-getting-started'
root_path  = '/content/gdrive/MyDrive/' if IS_COLAB else '/'
input_path = f'{root_path}kaggle/input/{project_name}/'
working_path = f'{input_path}working/' if IS_COLAB else '/kaggle/working/'
os.makedirs(working_path, exist_ok=True)
os.chdir(working_path)
os.listdir(input_path)
['working']

Loading data

TFRecord basics

# Public flower-photo bucket and TFRecord sharding parameters.
GCS_PATTERN = 'gs://flowers-public/*/*.jpg'
GCS_OUTPUT = 'gs://flowers-public/tfrecords-jpeg-192x192-2/flowers'
SHARDS = 16  # number of output TFRecord files
TARGET_SIZE = [192, 192]  # [height, width] after resize+crop
# bytes (not str) so they compare against tf string tensors in decode_image_and_label
CLASSES = [b'daisy', b'dandelion', b'roses', b'sunflowers', b'tulips']

Read images and labels

def decode_image_and_label(filename):
    """Read a JPEG from GCS and derive its one-hot label from the parent folder name."""
    raw = tf.io.read_file(filename)
    img = tf.image.decode_jpeg(raw)
    # the class is the second-to-last path component: gs://.../<class>/<file>.jpg
    parts = tf.strings.split(tf.expand_dims(filename, axis=-1), sep='/')
    folder = parts.values[-2]
    one_hot = tf.cast((CLASSES == folder), tf.int8)
    return img, one_hot
 filenames = tf.data.Dataset.list_files(GCS_PATTERN, seed=16)
for x in filenames.take(10): print(x)
tf.Tensor(b'gs://flowers-public/tulips/251811158_75fa3034ff.jpg', shape=(), dtype=string)
tf.Tensor(b'gs://flowers-public/daisy/506348009_9ecff8b6ef.jpg', shape=(), dtype=string)
tf.Tensor(b'gs://flowers-public/daisy/2019064575_7656b9340f_m.jpg', shape=(), dtype=string)
tf.Tensor(b'gs://flowers-public/tulips/8713396140_5af8136136.jpg', shape=(), dtype=string)
tf.Tensor(b'gs://flowers-public/roses/218630974_5646dafc63_m.jpg', shape=(), dtype=string)
tf.Tensor(b'gs://flowers-public/roses/410421672_563550467c.jpg', shape=(), dtype=string)
tf.Tensor(b'gs://flowers-public/tulips/8614237582_74417799f4_m.jpg', shape=(), dtype=string)
tf.Tensor(b'gs://flowers-public/dandelion/8797114213_103535743c_m.jpg', shape=(), dtype=string)
tf.Tensor(b'gs://flowers-public/dandelion/11296320473_1d9261ddcb.jpg', shape=(), dtype=string)
tf.Tensor(b'gs://flowers-public/dandelion/14554897292_b3e30e52f2.jpg', shape=(), dtype=string)
# dataset of (decoded image, one-hot label) pairs
ds0 = filenames.map(decode_image_and_label, num_parallel_calls=AUTOTUNE)
def show_images(ds):
    """Display the first 9 (image, one-hot label) pairs of `ds` on a 3x3 grid."""
    _, axes = plt.subplots(3, 3, figsize=(16, 16))
    for (img, lbl), ax in zip(ds.take(9), axes.flatten()):
        ax.imshow(img.numpy().astype(np.uint8))
        ax.set_title(np.argmax(lbl))  # title = integer class index
        ax.axis('off')
 show_images(ds0)

Resize and crop images to common size

No need to study the code in this cell. It's only image resizing.

def resize_and_crop_image(image, label):
    # Resize and crop using "fill" algorithm:
    # always make sure the resulting image
    # is cut out from the source image so that
    # it fills the TARGET_SIZE entirely with no
    # black bars and a preserved aspect ratio.
    #
    # NOTE(review): `w` here is tf.shape(image)[0] (the HEIGHT) and `h` is the
    # WIDTH — the names are swapped relative to convention. The arithmetic is
    # internally consistent, so behavior appears correct; left untouched.
    w = tf.shape(image)[0]
    h = tf.shape(image)[1]
    tw = TARGET_SIZE[1]
    th = TARGET_SIZE[0]
    # < 1 when the image is "taller" than the target aspect ratio
    resize_crit = (w * th) / (h * tw)
    # note w*tw/w == tw and h*th/h == th: each branch scales the constraining
    # side exactly to the target while the other side overshoots, to be cropped
    image = tf.cond(resize_crit < 1,
                    lambda: tf.image.resize(image, [w*tw/w, h*tw/w]), # if true
                    lambda: tf.image.resize(image, [w*th/h, h*th/h])  # if false
                   )
    nw = tf.shape(image)[0]
    nh = tf.shape(image)[1]
    # center-crop the overflowing dimension down to TARGET_SIZE
    image = tf.image.crop_to_bounding_box(image, (nw - tw) // 2, (nh - th) // 2, tw, th)
    return image, label
# all images are now a uniform 192x192
ds1 = ds0.map(resize_and_crop_image, num_parallel_calls=AUTOTUNE) 
show_images(ds1)

Speed test: too slow

Google Cloud Storage is capable of great throughput but has a per-file access penalty. Run the cell below and see that throughput is around 8 images per second. That is too slow. Training on thousands of individual files will not work. We have to use the TFRecord format to group files together.

%%time
# throughput check: reading individual JPEGs from GCS is the bottleneck (~8 img/s)
for image,label in ds1.batch(8).take(10):
    print("Image batch shape {} {}".format(
        image.numpy().shape,
        [np.argmax(lbl) for lbl in label.numpy()]))
Image batch shape (8, 192, 192, 3) [2, 1, 4, 1, 4, 3, 3, 0]
Image batch shape (8, 192, 192, 3) [2, 3, 0, 1, 4, 3, 2, 4]
Image batch shape (8, 192, 192, 3) [1, 3, 1, 3, 0, 1, 0, 0]
Image batch shape (8, 192, 192, 3) [2, 1, 4, 0, 3, 3, 2, 3]
Image batch shape (8, 192, 192, 3) [2, 1, 3, 4, 1, 4, 4, 1]
Image batch shape (8, 192, 192, 3) [1, 1, 2, 4, 1, 4, 3, 2]
Image batch shape (8, 192, 192, 3) [3, 3, 3, 4, 4, 2, 1, 1]
Image batch shape (8, 192, 192, 3) [4, 2, 0, 2, 3, 0, 0, 1]
Image batch shape (8, 192, 192, 3) [2, 1, 1, 1, 4, 2, 0, 0]
Image batch shape (8, 192, 192, 3) [4, 2, 3, 1, 1, 2, 2, 2]
CPU times: user 1.56 s, sys: 93.6 ms, total: 1.65 s
Wall time: 5.71 s

Recompress the images

The bandwidth savings outweigh the decoding CPU cost

def recompress_image(image, label):
    """Re-encode a decoded image as size-optimized JPEG; also return its dimensions."""
    h = tf.shape(image)[0]
    w = tf.shape(image)[1]
    # encode_jpeg needs uint8; keep full chroma for quality
    encoded = tf.image.encode_jpeg(tf.cast(image, tf.uint8),
                                   optimize_size=True, chroma_downsampling=False)
    return encoded, label, h, w
IMAGE_SIZE = len(tf.io.gfile.glob(GCS_PATTERN))    # total number of source images
SHARD_SIZE = math.ceil(1.0 * IMAGE_SIZE / SHARDS)  # images per TFRecord shard
ds2 = ds1.map(recompress_image, num_parallel_calls=AUTOTUNE)
ds2 = ds2.batch(SHARD_SIZE) # sharding: there will be one "batch" of images per file

Why TFRecords?

TPUs have eight cores which act as eight independent workers. We can get data to each core more efficiently by splitting the dataset into multiple files or shards. This way, each core can grab an independent part of the data as it needs. The most convenient kind of file to use for sharding in TensorFlow is a TFRecord. A TFRecord is a binary file that contains sequences of byte-strings. Data needs to be serialized (encoded as a byte-string) before being written into a TFRecord. The most convenient way of serializing data in TensorFlow is to wrap the data with tf.Example. This is a record format based on Google's protobufs but designed for TensorFlow. It's more or less like a dict with some type annotations

# Serialization demo: tensor -> byte-string -> tensor round trip.
x = tf.constant([[1,2], [3, 4]], dtype=tf.uint8)
print(x)
tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=uint8)
x_in_bytes = tf.io.serialize_tensor(x)
print(x_in_bytes)
tf.Tensor(b'\x08\x04\x12\x08\x12\x02\x08\x02\x12\x02\x08\x02"\x04\x01\x02\x03\x04', shape=(), dtype=string)
# parsing back requires knowing the original dtype
print(tf.io.parse_tensor(x_in_bytes, out_type=tf.uint8))
tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=uint8)

A TFRecord is a sequence of bytes, so we have to turn our data into byte-strings before it can go into a TFRecord. We can use tf.io.serialize_tensor to turn a tensor into a byte-string and tf.io.parse_tensor to turn it back. It's important to keep track of your tensor's datatype (in this case tf.uint8) since you have to specify it when parsing the string back to a tensor again

Write dataset to TFRecord files

# NOTE(review): archived TFRecord-writing code, kept commented out so the
# public bucket is not rewritten. The commented copy below repairs three
# defects present in the original: (1) the `def _bytestring_feature` header
# line was missing, (2) the writer was bound as `outfile` but used as
# `out_file`, and (3) the shard filename lacked the "." before "tfrec" that
# both the glob "*tfrec" usage and count_data_items()'s regex "-([0-9]*)\."
# expect. Also note the feature keys here ("id", "one_hot") do not match the
# keys read_tfrecord() parses ("class", "one_hot_class") — confirm against
# the actual bucket schema before reusing.
#def _bytestring_feature(list_of_bytestrings): # bytes
#    return tf.train.Feature(bytes_list=tf.train.BytesList(value=list_of_bytestrings))
#
#def _int_feature(list_of_ints): # int64
#    return tf.train.Feature(int64_list=tf.train.Int64List(value=list_of_ints))
#
#def _float_feature(list_of_floats): # float32
#    return tf.train.Feature(float_list=tf.train.FloatList(value=list_of_floats))
#
#def to_tfrecord(tfrec_filewriter, img_bytes, label, height, width):
#    id = np.argmax(np.array(CLASSES)==label)
#    one_hot = np.eye(len(CLASSES))[id]
#    feature = {
#      "image": _bytestring_feature([img_bytes]), # one image in the list
#      "id":    _int_feature([id]),               # one class in the list
#      "label": _bytestring_feature([label]),     # fixed length (1) list of strings, the text label
#      "size" : _int_feature([height, width]),    # fixed length (2) list of ints
#      "one_hot": _float_feature(one_hot.tolist())# variable length  list of floats, n=len(CLASSES)
#    }
#    return tf.train.Example(features=tf.train.Features(feature=feature))
#for shard_id, (image, label, height, width) in ds2.enumerate():
#    shard_size = image.numpy().shape[0]
#    filename = GCS_OUTPUT + "{:02d}-{}.tfrec".format(shard_id, shard_size)
#
#    with tf.io.TFRecordWriter(filename) as out_file:
#        for i in range(shard_size):
#            example = to_tfrecord(out_file,
#                                  image.numpy()[i],
#                                  label.numpy()[i],
#                                  height.numpy()[i],
#                                  width.numpy()[i])
#            out_file.write(example.SerializeToString())
#        print("Wrote file {} containing {} records".format(filename, shard_size))

Note: Need to check if I can upload the tfrecords to the Kaggle for free
GCS_OUTPUT
'gs://flowers-public/tfrecords-jpeg-192x192-2/flowers'

Read from TFRecord Dataset

def read_tfrecord(example):
    """Parse one serialized tf.Example back into its component tensors."""
    features = {
        # tf.string = bytestring (not text string); shape [] means scalar
        "image": tf.io.FixedLenFeature([], tf.string),
        "class": tf.io.FixedLenFeature([], tf.int64),
        # extra (demo-only) fields showing the different storable types
        "label":         tf.io.FixedLenFeature([], tf.string),  # one bytestring
        "size":          tf.io.FixedLenFeature([2], tf.int64),  # two integers
        "one_hot_class": tf.io.VarLenFeature(tf.float32)        # variable-length floats
    }
    parsed = tf.io.parse_single_example(example, features)

    # FixedLenFeature fields are directly usable; VarLenFeature comes back
    # sparse and needs an explicit densify step.
    image = tf.reshape(tf.image.decode_jpeg(parsed['image'], channels=3),
                       [*TARGET_SIZE, 3])
    one_hot_class = tf.sparse.to_dense(parsed['one_hot_class'])

    return (image,
            parsed['class'],
            parsed['label'],
            parsed['size'][0],   # height
            parsed['size'][1],   # width
            one_hot_class)
# Read the sharded TFRecords back; order is irrelevant here, so allow
# non-deterministic interleaving for speed.
option_no_order = tf.data.Options()
option_no_order.experimental_deterministic = False

filenames = tf.io.gfile.glob(GCS_OUTPUT + "*tfrec")
ds3 = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
ds3 = (ds3.with_options(option_no_order)
         .map(read_tfrecord, num_parallel_calls=AUTOTUNE)
         .shuffle(30))
# keep only (image, text label) for display
ds3_to_show = ds3.map(lambda image, id, label, height, width, one_hot: (image, label))
show_images(ds3_to_show)

Speed test: fast

Loading training data is not a bottleneck anymore

%%time
# same throughput test against the sharded TFRecords — orders of magnitude faster
for image, class_num, label, height, width, one_hot_class in ds3.batch(8).take(10):
    print("Image batch shape {} {}".format(
        image.numpy().shape,
        [lbl.decode('utf8') for lbl in label.numpy()]))
Image batch shape (8, 192, 192, 3) ['roses', 'roses', 'dandelion', 'tulips', 'sunflowers', 'dandelion', 'sunflowers', 'tulips']
Image batch shape (8, 192, 192, 3) ['sunflowers', 'dandelion', 'tulips', 'tulips', 'sunflowers', 'sunflowers', 'roses', 'tulips']
Image batch shape (8, 192, 192, 3) ['daisy', 'tulips', 'roses', 'roses', 'tulips', 'daisy', 'sunflowers', 'sunflowers']
Image batch shape (8, 192, 192, 3) ['roses', 'roses', 'sunflowers', 'sunflowers', 'tulips', 'sunflowers', 'tulips', 'tulips']
Image batch shape (8, 192, 192, 3) ['daisy', 'dandelion', 'dandelion', 'daisy', 'tulips', 'dandelion', 'tulips', 'tulips']
Image batch shape (8, 192, 192, 3) ['roses', 'dandelion', 'roses', 'daisy', 'dandelion', 'dandelion', 'dandelion', 'dandelion']
Image batch shape (8, 192, 192, 3) ['dandelion', 'sunflowers', 'daisy', 'sunflowers', 'sunflowers', 'sunflowers', 'dandelion', 'daisy']
Image batch shape (8, 192, 192, 3) ['roses', 'daisy', 'dandelion', 'dandelion', 'sunflowers', 'daisy', 'dandelion', 'dandelion']
Image batch shape (8, 192, 192, 3) ['sunflowers', 'sunflowers', 'roses', 'dandelion', 'sunflowers', 'dandelion', 'dandelion', 'dandelion']
Image batch shape (8, 192, 192, 3) ['daisy', 'sunflowers', 'tulips', 'roses', 'roses', 'roses', 'sunflowers', 'sunflowers']
CPU times: user 219 ms, sys: 34.1 ms, total: 253 ms
Wall time: 212 ms

Constructing data

# Competition dataset (104 flower classes) mirrored on GCS by Kaggle.
GCS_DS_PATH = 'gs://kds-e93303da9a97ef8fd254ceb5e9ed104470f247527dd45aba9685bdf5'
GCS_PATH = GCS_DS_PATH + '/tfrecords-jpeg-512x512'
IMG_SIZE = [512, 512]
BATCH_SIZE = 16 * strategy.num_replicas_in_sync  # scale global batch with TPU cores
CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']

# trade read determinism for throughput (fine for training data)
option_no_order = tf.data.Options()
option_no_order.experimental_deterministic = False
def decode_image(image_data):
    """Decode JPEG bytes into a float32 image in [0, 1] with a static shape."""
    img = tf.image.decode_jpeg(image_data, channels=3)
    img = tf.cast(img, tf.float32) / 255.0
    # TPU compilation requires every tensor shape to be known statically
    return tf.reshape(img, [*IMG_SIZE, 3])
    
def collate_labeled_tfrecord(example):
    """Parse a labeled tf.Example into an (image, int32 class id) pair."""
    spec = {
        "image": tf.io.FixedLenFeature([], tf.string), # JPEG bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # scalar class id
    }
    parsed = tf.io.parse_single_example(example, spec)
    return decode_image(parsed['image']), tf.cast(parsed['class'], tf.int32)

def process_unlabeled_tfrecord(example):
    """Parse a test-set tf.Example into an (image, id string) pair."""
    spec = {
        "image": tf.io.FixedLenFeature([], tf.string), # JPEG bytestring
        "id": tf.io.FixedLenFeature([], tf.string),    # sample identifier
        # no "class" field: predicting it is the competition's challenge
    }
    parsed = tf.io.parse_single_example(example, spec)
    return decode_image(parsed['image']), parsed['id']

# compiled once at module level instead of once per filename per call
_TFREC_COUNT_RE = re.compile(r"-([0-9]+)\.")

def count_data_items(filenames):
    """Sum the record counts encoded in .tfrec file names.

    The count is embedded in each name, e.g. flowers00-230.tfrec holds 230
    items. Fixes vs. original: the pattern is hoisted out of the loop, and
    `[0-9]+` replaces `[0-9]*`, which could match an empty digit run and
    crash on int('').
    """
    return np.sum([int(_TFREC_COUNT_RE.search(fn).group(1)) for fn in filenames])
train_filenames = tf.io.gfile.glob(GCS_PATH + '/train/*.tfrec')
valid_filenames = tf.io.gfile.glob(GCS_PATH + '/val/*.tfrec')
test_filenames  = tf.io.gfile.glob(GCS_PATH + '/test/*.tfrec') 
# train: repeat forever + shuffle; model.fit bounds epochs via steps_per_epoch
train_ds = tf.data.TFRecordDataset(train_filenames, num_parallel_reads=AUTOTUNE)
train_ds = (train_ds
            .map(collate_labeled_tfrecord)
            .repeat()
            .shuffle(2048)
            .batch(BATCH_SIZE)
            .prefetch(AUTOTUNE))
# valid: cached after the first full pass (no shuffling needed)
valid_ds = tf.data.TFRecordDataset(valid_filenames, num_parallel_reads=AUTOTUNE)
valid_ds = (valid_ds
            .with_options(option_no_order)
            .map(collate_labeled_tfrecord)
            .batch(BATCH_SIZE)
            .cache()
            .prefetch(AUTOTUNE))
# test: unlabeled (image, id) pairs for the submission file
test_ds = tf.data.TFRecordDataset(test_filenames, num_parallel_reads=AUTOTUNE)
test_ds = (test_ds
            .with_options(option_no_order)
            .map(process_unlabeled_tfrecord)
            .batch(BATCH_SIZE)
            .prefetch(AUTOTUNE))
# Model-selection form fields.
# NOTE(review): this reassigns IMG_SIZE (previously [512, 512] for the data
# pipelines above) to (300, 300); downstream code (RandomCrop, VGG16
# input_shape) uses the new value. BATCH_SIZE is likewise overwritten from
# 16 * REPLICAS to a flat 32 — confirm this is intentional.
BASE_MODEL, IMG_SIZE = ('efficientnet_b3', 300) #@param ["('efficientnet_b3', 300)", "('efficientnet_b4', 380)", "('efficientnet_b2', 260)"] {type:"raw", allow-input: true}
BATCH_SIZE = 32 #@param {type:"integer"}
IMG_SIZE = (IMG_SIZE, IMG_SIZE) #@param ["(IMG_SIZE, IMG_SIZE)", "(512,512)"] {type:"raw"}
print("Using {} with input size {}".format(BASE_MODEL, IMG_SIZE))
Using efficientnet_b3 with input size (300, 300)

Batch augmentation

# Keras preprocessing-layer augmentation pipeline (runs in-graph, on device).
data_augmentation = tf.keras.Sequential(
    [
     tf.keras.layers.experimental.preprocessing.RandomCrop(*IMG_SIZE),
     tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
     tf.keras.layers.experimental.preprocessing.RandomRotation(0.25),
     tf.keras.layers.experimental.preprocessing.RandomZoom((-0.2, 0)),
     tf.keras.layers.experimental.preprocessing.RandomContrast((0.2,0.2))
    ]
)
# preview: augment a single training batch
func = lambda x,y: (data_augmentation(x), y)
x = (train_ds
     .take(1)
     .map(func, num_parallel_calls=AUTOTUNE))

Building a model

Now we're ready to create a neural network for classifying images! We'll use what's known as transfer learning. With transfer learning, you reuse part of a pretrained model to get a headstart on a new dataset.

For this tutorial, we'll use a model called VGG16 pretrained on ImageNet. Later, you might want to experiment with other models included with Keras. (Xception wouldn't be a bad choice.)

The distribution strategy we created earlier contains a context manager, strategy.scope. This context manager tells TensorFlow how to divide the work of training among the eight TPU cores. When using TensorFlow with a TPU, it's important to define your model in a strategy.scope() context.

EPOCHS = 12
# Variable creation must happen inside strategy.scope() so the weights are
# replicated across all TPU cores.
with strategy.scope():
    # frozen ImageNet-pretrained VGG16 backbone, used purely as a feature extractor
    pretrained_model = tf.keras.applications.VGG16(
        weights='imagenet',
        include_top=False ,
        input_shape=[*IMG_SIZE, 3]
    )
    pretrained_model.trainable = False
    
    model = tf.keras.Sequential([
        # To a base pretrained on ImageNet to extract features from images...
        pretrained_model,
        # ... attach a new head to act as a classifier.
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(CLASSES), activation='softmax')
    ])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
58892288/58889256 [==============================] - 1s 0us/step

CosineDecay

Important: I always wanted to try the new CosineDecayRestarts function implemented in tf.keras as it seemed promising and I struggled to find the right settings (if there were any) for the ReduceLROnPlateau
EPOCHS = 12
# total optimizer steps over the whole run (used to plot the schedule below)
STEPS = int(round(count_data_items(train_filenames)/BATCH_SIZE)) * EPOCHS
STEPS_PER_EPOCH = count_data_items(train_filenames) // BATCH_SIZE
# cosine decay with warm restarts; restart interval doubles each cycle (t_mul=2.0)
schedule = tf.keras.experimental.CosineDecayRestarts(
    initial_learning_rate=1e-4,
    first_decay_steps=180
)
schedule.get_config()
{'alpha': 0.0,
 'first_decay_steps': 180,
 'initial_learning_rate': 0.0001,
 'm_mul': 1.0,
 'name': None,
 't_mul': 2.0}
# visualize the learning rate across every training step
x = [i for i in range(STEPS)]
y = [schedule(s) for s in range(STEPS)]
plt.plot(x, y)
[<matplotlib.lines.Line2D at 0x7f1ac39379e8>]
# keep only the checkpoint with the best validation loss
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath='001_best_model.h5',
        monitor='val_loss',
        save_best_only=True),
    ]
    
# sparse_* loss/metric because labels are integer class ids, not one-hot vectors
model.compile(
    optimizer=tf.keras.optimizers.Adam(schedule),
    loss = 'sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy'],
)
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
vgg16 (Functional)           (None, 16, 16, 512)       14714688  
_________________________________________________________________
global_average_pooling2d (Gl (None, 512)               0         
_________________________________________________________________
dense (Dense)                (None, 104)               53352     
=================================================================
Total params: 14,768,040
Trainable params: 53,352
Non-trainable params: 14,714,688
_________________________________________________________________
# steps_per_epoch is required because train_ds repeats indefinitely
history = model.fit(
    x=train_ds,
    validation_data=valid_ds,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    callbacks=callbacks
)
Epoch 1/12
99/99 [==============================] - 116s 880ms/step - loss: 4.6869 - sparse_categorical_accuracy: 0.0187 - val_loss: 4.4758 - val_sparse_categorical_accuracy: 0.0558
Epoch 2/12
99/99 [==============================] - 64s 652ms/step - loss: 4.4573 - sparse_categorical_accuracy: 0.0628 - val_loss: 4.3871 - val_sparse_categorical_accuracy: 0.0598
Epoch 3/12
99/99 [==============================] - 62s 629ms/step - loss: 4.3382 - sparse_categorical_accuracy: 0.0624 - val_loss: 4.2258 - val_sparse_categorical_accuracy: 0.0617
Epoch 4/12
99/99 [==============================] - 64s 645ms/step - loss: 4.2182 - sparse_categorical_accuracy: 0.0598 - val_loss: 4.1701 - val_sparse_categorical_accuracy: 0.0644
Epoch 5/12
99/99 [==============================] - 63s 641ms/step - loss: 4.1711 - sparse_categorical_accuracy: 0.0591 - val_loss: 4.1564 - val_sparse_categorical_accuracy: 0.0663
Epoch 6/12
99/99 [==============================] - 63s 643ms/step - loss: 4.1523 - sparse_categorical_accuracy: 0.0626 - val_loss: 4.1195 - val_sparse_categorical_accuracy: 0.0711
Epoch 7/12
99/99 [==============================] - 63s 639ms/step - loss: 4.1093 - sparse_categorical_accuracy: 0.0713 - val_loss: 4.0741 - val_sparse_categorical_accuracy: 0.0814
Epoch 8/12
99/99 [==============================] - 62s 634ms/step - loss: 4.0724 - sparse_categorical_accuracy: 0.0815 - val_loss: 4.0420 - val_sparse_categorical_accuracy: 0.0938
Epoch 9/12
99/99 [==============================] - 64s 646ms/step - loss: 4.0393 - sparse_categorical_accuracy: 0.0904 - val_loss: 4.0194 - val_sparse_categorical_accuracy: 0.0981
Epoch 10/12
99/99 [==============================] - 62s 631ms/step - loss: 4.0132 - sparse_categorical_accuracy: 0.1000 - val_loss: 4.0051 - val_sparse_categorical_accuracy: 0.1032
Epoch 11/12
99/99 [==============================] - 62s 630ms/step - loss: 4.0041 - sparse_categorical_accuracy: 0.1026 - val_loss: 3.9977 - val_sparse_categorical_accuracy: 0.1070
Epoch 12/12
99/99 [==============================] - 62s 635ms/step - loss: 3.9913 - sparse_categorical_accuracy: 0.1082 - val_loss: 3.9952 - val_sparse_categorical_accuracy: 0.1094
def plot_hist(hist):
    """Plot train/validation loss curves from a Keras History object.

    Fix: the function previously ignored its `hist` argument and read the
    module-level `history` variable instead, so it could only ever plot the
    most recent global fit.
    """
    plt.plot(hist.history['loss'])
    plt.plot(hist.history['val_loss'])
    plt.title('Loss over epochs')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'valid'], loc='best')
    plt.show()

def display_training_curves(training, validation, title, subplot):
    # Plot a train vs. validation series on one pane of a stacked figure;
    # `subplot` is a matplotlib 3-digit spec (e.g. 211 = top of a 2x1 grid).
    # NOTE(review): this function is redefined byte-identically later in the
    # file — the duplicate definition could be removed.
    if subplot%10==1: # set up the subplots on the first call
        plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    ax.plot(training)
    ax.plot(validation)
    ax.set_title('model '+ title)
    ax.set_ylabel(title)
    #ax.set_ylim(0.28,1.05)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])

import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix

def display_confusion_matrix(cmat, score, precision, recall):
    # Heatmap of the (row-normalized) confusion matrix with class names on
    # both axes; f1/precision/recall are printed top-right when not None.
    plt.figure(figsize=(15,15))
    ax = plt.gca()
    ax.matshow(cmat, cmap='Reds')
    ax.set_xticks(range(len(CLASSES)))
    ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    ax.set_yticks(range(len(CLASSES)))
    ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    titlestring = ""
    if score is not None:
        titlestring += 'f1 = {:.3f} '.format(score)
    if precision is not None:
        titlestring += '\nprecision = {:.3f} '.format(precision)
    if recall is not None:
        titlestring += '\nrecall = {:.3f} '.format(recall)
    if len(titlestring) > 0:
        # x=101: just inside the right edge of the 104-class axis
        ax.text(101, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})
    plt.show()
    
def display_training_curves(training, validation, title, subplot):
    # NOTE(review): byte-identical duplicate of the definition earlier in the
    # file — safe to delete one copy.
    if subplot%10==1: # set up the subplots on the first call
        plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    ax.plot(training)
    ax.plot(validation)
    ax.set_title('model '+ title)
    ax.set_ylabel(title)
    #ax.set_ylim(0.28,1.05)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])
# pane 211: loss (top), pane 212: accuracy (bottom) of a 2x1 figure
display_training_curves(
    history.history['loss'],
    history.history['val_loss'],
    'loss',
    211,
)
display_training_curves(
    history.history['sparse_categorical_accuracy'],
    history.history['val_sparse_categorical_accuracy'],
    'accuracy',
    212,
)
# Rebuild the validation pipeline WITHOUT shuffling so predictions stay
# aligned with their ground-truth labels.
cmat_ds = tf.data.TFRecordDataset(valid_filenames, num_parallel_reads=AUTOTUNE)
cmat_ds = (cmat_ds
            .map(collate_labeled_tfrecord)
            .batch(BATCH_SIZE)
            .cache()
            .prefetch(AUTOTUNE))

images_ds = cmat_ds.map(lambda image, label: image)
labels_ds = cmat_ds.map(lambda image, label: label).unbatch()
# one mega-batch containing every validation label, pulled into numpy
cm_correct_labels = next(iter(labels_ds.batch(count_data_items(valid_filenames)))).numpy()
cm_probabilities = model.predict(images_ds)
cm_predictions = np.argmax(cm_probabilities, axis=-1)

labels = range(len(CLASSES))
cmat = confusion_matrix(
    cm_correct_labels,
    cm_predictions,
    labels=labels,
)
# row-normalize so each true class sums to 1
# NOTE(review): a class with zero validation samples would make its row sum 0
# and yield NaNs here.
cmat = (cmat.T / cmat.sum(axis=1)).T # normalize

You might be familiar with metrics like F1-score or precision and recall. This cell will compute these metrics and display them with a plot of the confusion matrix. (These metrics are defined in the Scikit-learn module sklearn.metrics; we've imported them in the helper script for you.)

# macro averaging weights all 104 classes equally, regardless of support
score = f1_score(
    cm_correct_labels,
    cm_predictions,
    labels=labels,
    average='macro',
)
precision = precision_score(
    cm_correct_labels,
    cm_predictions,
    labels=labels,
    average='macro',
)
recall = recall_score(
    cm_correct_labels,
    cm_predictions,
    labels=labels,
    average='macro',
)
display_confusion_matrix(cmat, score, precision, recall)

Visual Validation

It can also be helpful to look at some examples from the validation set and see what class your model predicted. This can help reveal patterns in the kinds of images your model has trouble with. This cell will set up the validation set to display 20 images at a time -- you can change this to display more or fewer, if you like.

# NOTE(review): this cell was left unfinished — a bare `ds = ` is a
# SyntaxError. Commented out until the visual-validation dataset is built,
# e.g.:
# ds = valid_ds.unbatch().batch(20)

1% Better Everyday

reference


todos

  • Comment out the 1/255.0 in the image preprocessing
  • Reorganize the notebook structure

done